In [2]:
# Imports — stdlib first, then third-party (deduplicated: pickle was imported twice).
import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import plotly
import plotly.express as px
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde

# Shorthand for the TFP distributions namespace, used throughout the notebook.
tfd = tfp.distributions

# Enable offline plotly rendering inside the notebook.
plotly.offline.init_notebook_mode()
In [10]:
def plot_all(varient=''):
    """Plot posterior marginals for the linear-regression-with-noise model.

    Loads MCMC samples (Blackjax) and a variational posterior (Ajax) from
    ``./results_data/``, smooths each marginal with a Gaussian KDE, and
    renders two interactive plotly figures: one for the regression
    coefficients (theta0 / theta1) and one for the noise variance.

    Parameters
    ----------
    varient : str or value convertible to str, default ''
        Suffix appended to the result-file names, selecting which
        experiment's pickles to load.  (Parameter name kept as-is —
        note the 'variant' typo — to preserve the call interface.)
    """
    varient = str(varient)
    # Common evaluation grid shared by every KDE curve.
    x = jnp.linspace(0, 6, 10000)

    all_pdfs, all_labels = [], []
    all_pdfs_noise, all_labels_noise = [], []

    # --- Blackjax MCMC samples -------------------------------------------
    # NOTE(review): pickle.load executes arbitrary code — only open trusted
    # result files produced by this project.
    with open('./results_data/MCMC_Blackjax' + varient, 'rb') as f:
        black_samples = pickle.load(f)

    for i in range(2):
        kde_black = gaussian_kde(black_samples.position['theta'][:, i])
        all_pdfs.append(kde_black(x))
        all_labels.append(f'Blackjax rmh theta{i}')

    kde_black = gaussian_kde(black_samples.position['noise_var'])
    all_pdfs_noise.append(kde_black(x))
    all_labels_noise.append("Blackjax rmh noise")

    # --- Ajax variational posterior --------------------------------------
    with open("./results_data/ajax_model" + varient, 'rb') as f:
        posterior = pickle.load(f)

    # Fixed seed so the KDE curves are reproducible across notebook runs.
    samples_ajax = posterior.sample(seed=jax.random.PRNGKey(10),
                                    sample_shape=(10000,))
    for i in range(2):
        kde_ajax = gaussian_kde(samples_ajax["theta"][:, i])
        all_pdfs.append(kde_ajax(x))
        all_labels.append(f"Ajax VI theta{i}")

    kde_ajax = gaussian_kde(samples_ajax["noise"])
    all_pdfs_noise.append(kde_ajax(x))
    all_labels_noise.append("Ajax VI noise")

    def create_df(pdfs, labels, grid):
        """Stack per-label PDF arrays into a long-format frame for px.line."""
        flat_pdfs = jnp.array(pdfs).reshape((-1))
        # Repeat each label once per grid point so rows align with flat_pdfs.
        labels_repeated = [lab for lab in labels for _ in range(grid.shape[0])]
        grid_repeated = jnp.tile(grid, len(labels))
        return pd.DataFrame({
            "theta": grid_repeated,
            "PDF": flat_pdfs,
            "label": labels_repeated,
        })

    fig = px.line(create_df(all_pdfs, all_labels, x),
                  "theta", "PDF", color="label",
                  title="Linear regression posterior")
    fig.show()

    # Distinct title for the noise figure (was a copy-paste duplicate).
    fig = px.line(create_df(all_pdfs_noise, all_labels_noise, x),
                  "theta", "PDF", color="label",
                  title="Linear regression noise posterior")
    fig.show()
In [11]:
# Render the posterior plots for the default ('' suffix) result files.
plot_all()
In [ ]:
# Export this notebook to a standalone HTML report for sharing.
!jupyter nbconvert --to HTML linear_regression_noise_results.ipynb